feat: support ring attention with arbitrary KV sequence lengths #13545
songh11 wants to merge 4 commits into huggingface:main
Conversation
Thanks for your PR. Could you also provide some visual output examples and a comparison with the other backends we support (latency, memory overhead, etc.)?
@sayakpaul Thanks for the review! I just pushed two follow-up commits:
FLUX.1-dev benchmarks on 4 × RTX 4090 (48GB)
Visual outputs: I'll attach side-by-side generated images (same prompt + seed): single-GPU baseline vs. ring_anything.
Thanks for the updates -- they are very comprehensive. It seems like ulysses_anything is still the best, but users who might prefer ring would want to use ring_anything for non-divisible shapes? Is that the motivation here?
Thanks for the review! You're spot on — the primary motivation is to provide a seamless path for users on a ring topology to handle non-divisible sequence lengths, similar to how ulysses_anything works. While ulysses/ulysses_anything is often the go-to, I believe ring_anything is a valuable addition for a few key reasons:
So while ulysses/ulysses_anything remains the recommended choice for most users, ring_anything covers head-constrained or ring-first setups: the focus today is on correctness, with clear headroom for P2P overlap as the next step.
Your reasoning is spot on! Let's include them in the docs too.
@claude could you do a review here?
Claude finished @sayakpaul's task in 2m 29s

Review of PR #13545: Ring Anything Attention
Overall this is a well-structured PR that cleanly mirrors the existing patterns. The core algorithm (pad → all-gather → slice → online-softmax merge) is correct, and the config validation + hook routing are consistent with the existing ulysses_anything path.

**Duplicate hook branches should be merged**
```python
if self.parallel_config.ulysses_anything:
    return PartitionAnythingSharder.shard_anything(...)
elif self.parallel_config.ring_anything:
    return PartitionAnythingSharder.shard_anything(...)
```

Since the bodies are identical, these should be merged:

```python
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
    return PartitionAnythingSharder.shard_anything(...)
```

Same for the [...]

**Incomplete docstring comment for** [...]
@sayakpaul Thanks! Addressed the feedback by merging duplicate hooks, adding the exclusivity check, and polishing comments. Pushed in the latest commit. |
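For readers following along, the "exclusivity check" mentioned above might look roughly like a dataclass-style validation. This is a simplified stand-in sketch, not diffusers' actual `ContextParallelConfig`:

```python
from dataclasses import dataclass

@dataclass
class CPConfigSketch:  # simplified stand-in for ContextParallelConfig
    ring_degree: int = 1
    ulysses_degree: int = 1
    ring_anything: bool = False
    ulysses_anything: bool = False

    def __post_init__(self):
        # The two "anything" modes target different parallel topologies,
        # so enabling both at once is rejected up front.
        if self.ring_anything and self.ulysses_anything:
            raise ValueError("ring_anything and ulysses_anything are mutually exclusive")
        # ring_anything only makes sense on a pure ring topology.
        if self.ring_anything and not (self.ring_degree > 1 and self.ulysses_degree == 1):
            raise ValueError("ring_anything requires ring_degree > 1 and ulysses_degree == 1")
```

Failing fast in `__post_init__` keeps invalid topology combinations from surfacing later as shape errors deep inside the attention path.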
sayakpaul
left a comment
Thanks for the updates. I left some further comments.
### Ring Anything Attention
The default Ring Attention requires the sequence length of hidden states to be evenly divisible across the ring degree. Ring Anything Attention is a variant of Ring Attention that supports arbitrary (non-evenly divisible) sequence lengths. It pads each rank's local KV to the global maximum sequence length, all-gathers the padded KV buffer, and slices back to each rank's true length before running attention.
Do we want to add a link for Ring Anything Attention here?
```python
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2, ring_anything=True))
```
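The pad → all-gather → slice step described above can be sketched in a single process, with a plain Python list standing in for `torch.distributed.all_gather`. The `[batch, seq, dim]` layout and the function name are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def gather_anything(local_kvs):
    # local_kvs: one [batch, seq, dim] tensor per "rank", with differing seq.
    seq_lens = [kv.shape[1] for kv in local_kvs]  # true per-rank KV lengths
    s_max = max(seq_lens)
    # 1) pad every rank's KV along the sequence dim to the global maximum,
    #    so all ranks contribute equally sized buffers
    padded = [F.pad(kv, (0, 0, 0, s_max - kv.shape[1])) for kv in local_kvs]
    # 2) "all-gather" the buffers (a list stands in for the collective)
    gathered = list(padded)
    # 3) slice each buffer back to its true length before attention,
    #    so the zero padding never contributes to the softmax
    return [g[:, :n] for g, n in zip(gathered, seq_lens)]
```

In the real implementation step 2 is a `torch.distributed` collective over equally sized buffers, which is exactly why the padding in step 1 is needed.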
> [!TIP]
> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency.
Use a code snippet to make it clear how that should be incorporated in the code, and hyperlink to any docs that discuss this benefit of "gloo".
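A possible shape for the requested snippet (single-process here so it runs standalone; under a real `torchrun` launch the env vars are set for you and the backend string would register both backends):

```python
import os
import torch
import torch.distributed as dist

# Single-process illustration; torchrun sets these in a real launch.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# On a GPU cluster: dist.init_process_group(backend="cpu:gloo,cuda:nccl")
# Registering Gloo for CPU tensors lets small metadata exchanges (e.g. the
# per-rank KV sequence lengths) stay on the CPU, avoiding the forced
# H2D/D2H copies and CUDA syncs an NCCL-only setup would incur.
dist.init_process_group(backend="gloo", rank=0, world_size=1)

local_len = torch.tensor([1008])  # CPU tensor, exchanged over Gloo
all_lens = [torch.zeros_like(local_len) for _ in range(dist.get_world_size())]
dist.all_gather(all_lens, local_len)
dist.destroy_process_group()
```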
| ulysses_anything | 253.16 | 3.95 | 33.75 | 1008x1008 |
| ring_anything | 335.57 | 2.98 | 33.75 | 1008x1008 |
From the above table, Ring Anything Attention offers compatibility with arbitrary sequence lengths while maintaining performance comparable to the standard Ring Attention.
I would also discuss the limitations mitigated by Ring Anything:
#13545 (comment)
```python
_parallel_config: "ParallelConfig" | None = None,
):
    # Ring attention for arbitrary sequence lengths.
    if attn_mask is not None:
```
Seems like a pretty big limitation no? This would make it incompatible with models like QwenImage, right?
```python
    return t
pad_shape = list(t.shape)
pad_shape[1] = pad_len
return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)
```
We should be able to directly do: `torch.zeros(t.shape, ...)`.
```python
pad_shape[1] = pad_len
return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)

key_padded = pad_to_s_max(key)
```
Would add a small explainer comment.
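Along the lines of the requested explainer, the helper might carry comments like the following (the `pad_to_s_max(t, s_max)` signature is an assumption based on the excerpt):

```python
import torch

def pad_to_s_max(t, s_max):
    # t: [batch, seq, ...]. Zero-pad the sequence dim (dim 1) up to s_max,
    # the global maximum KV length across ranks, so every rank contributes
    # an equally sized buffer to the all-gather. The padding is harmless:
    # the gathered buffer is sliced back to each rank's true length before
    # attention, so padded rows never enter the softmax.
    pad_len = s_max - t.shape[1]
    if pad_len == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[1] = pad_len
    return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)
```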
```python
for i in range(world_size):
    if i > 0:
        true_seq_len = all_kv_seq_lens[next_rank]
        kv = kv_buffer[next_rank]
        # Reshape to padded shape, then slice to true sequence length
        key = kv[:kv_padded_numel].reshape_as(key_padded)[:, :true_seq_len]
        value = kv[kv_padded_numel:].reshape_as(value_padded)[:, :true_seq_len]
        next_rank = (next_rank + 1) % world_size
    else:
        # i == 0: use local (unpadded) key/value
        key = key_padded[:, :kv_seq_len]
        value = value_padded[:, :kv_seq_len]

    out, lse = forward_op(
        ctx,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        True,
```
@claude can we use torch.where here for a better conditional flow graph?
What does this PR do?
Adds a new "Ring Anything" context-parallel attention mode that supports arbitrary
(non-evenly divisible) KV sequence lengths across ring-degree workers.
Motivation
Existing `TemplatedRingAttention` requires KV to be equipartitioned across ranks, which is impractical for real-world workloads where per-rank sequence lengths can differ (e.g., variable-length prompts, packed batches, token pruning). This PR mirrors the existing `ulysses_anything` design but applies it to the ring path.

Changes

- `ContextParallelConfig`: add `ring_anything` flag with validation (`ring_degree > 1` and `ulysses_degree == 1`).
- `TemplatedRingAnythingAttention`: new autograd Function that implements the pad → all-gather → slice → online-softmax merge.
- `_templated_context_parallel_attention`: dispatch to the new class when `ring_anything` is enabled.
- `ContextParallelSplitHook`: route through `PartitionAnythingSharder.shard_anything` when `ring_anything` is set.

Reproducible example

Launch
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@yiyixuxu @asomoza @sayakpaul